{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# {meth}`torch.Tensor.repeat` 转换"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm-book/doc/tutorials/frontend\n"
     ]
    }
   ],
   "source": [
    "%cd ../../..\n",
    "import set_env\n",
    "from d2py.utils.file import mkdir\n",
    "temp_dir = \".temp\"\n",
    "mkdir(temp_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exported graph: graph(%data : Float(3, 1, 4, strides=[4, 4, 1], requires_grad=0, device=cpu)):\n",
      "  %onnx::Tile_1 : Long(3, strides=[1], device=cpu) = onnx::Constant[value= 3  3  4 [ CPULongType{3} ]]()\n",
      "  %/Constant_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={3}, onnx_name=\"/Constant\"](), scope: __main__.Model:: # /tmp/ipykernel_1020703/3833725313.py:6:0\n",
      "  %/ConstantOfShape_output_0 : Long(3, strides=[1], device=cpu) = onnx::ConstantOfShape[value={1}, onnx_name=\"/ConstantOfShape\"](%/Constant_output_0), scope: __main__.Model:: # /tmp/ipykernel_1020703/3833725313.py:6:0\n",
      "  %/Expand_output_0 : Float(3, 1, 4, strides=[4, 4, 1], device=cpu) = onnx::Expand[onnx_name=\"/Expand\"](%data, %/ConstantOfShape_output_0), scope: __main__.Model:: # /tmp/ipykernel_1020703/3833725313.py:6:0\n",
      "  %output : Float(9, 3, 16, strides=[48, 16, 1], requires_grad=0, device=cpu) = onnx::Tile[onnx_name=\"/Tile\"](%/Expand_output_0, %onnx::Tile_1), scope: __main__.Model:: # /tmp/ipykernel_1020703/3833725313.py:6:0\n",
      "  return (%output)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "class Model(nn.Module):\n",
    "    def forward(self, x):\n",
    "        return x.repeat(3, 3, 4)\n",
    "\n",
    "shape = 3, 1, 4\n",
    "x = torch.rand(*shape)\n",
    "\n",
    "torch_model = Model()\n",
    "# 导出模型\n",
    "output_name = \"repeat\"\n",
    "torch.onnx.export(\n",
    "    torch_model,               # torch 模型\n",
    "    x,                         # 模型输入或者对于多个输入，使用元组\n",
    "    f\"{temp_dir}/{output_name}.onnx\",               # 模型保存的位置（可以是文件或类似文件的对象）\n",
    "    export_params=True,        # 将训练后的参数权重存储在模型文件内\n",
    "    opset_version=17,          # 导出模型的 ONNX 版本\n",
    "    do_constant_folding=True,  # 是否执行常量折叠以进行优化\n",
    "    input_names = ['data'],    # 模型的输入名称\n",
    "    output_names = ['output'], # 模型的输出名称\n",
    "    verbose=True,\n",
    "    # dynamic_axes={'data' : {0 : 'batch_size'},    # 可变长度的轴\n",
    "    #               'output' : {0 : 'batch_size'}}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![模型结构](images/repeat.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# onnx_program = torch.onnx.dynamo_export(torch_model, x)\n",
    "# onnx_program.save(\"test_dynamo.onnx\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>data: Tensor[(<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">4</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">4</span>), float32] span<span style=\"color: #AA22FF; font-weight: bold\">=/</span>Expand<span style=\"color: #AA22FF; font-weight: bold\">.</span>data:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">9</span>i64, <span style=\"color: #008000\">3</span>i64, <span style=\"color: #008000\">16</span>i64), float32] {\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> broadcast_to(<span style=\"color: #AA22FF; font-weight: bold\">%</span>data, shape<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">3</span>i64, <span style=\"color: #008000\">1</span>i64, <span style=\"color: #008000\">4</span>i64]) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">3</span>i64, <span style=\"color: #008000\">1</span>i64, <span style=\"color: #008000\">4</span>i64), float32] span<span style=\"color: #AA22FF; font-weight: bold\">=/</span>Expand:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>;\n",
       "  tile(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, reps<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">3</span>i64, <span style=\"color: #008000\">3</span>i64, <span style=\"color: #008000\">4</span>i64]) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">9</span>i64, <span style=\"color: #008000\">3</span>i64, <span style=\"color: #008000\">16</span>i64), float32] span<span style=\"color: #AA22FF; font-weight: bold\">=/</span>Tile:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import onnx\n",
    "import tvm\n",
    "from tvm import relay\n",
    "onnx_model = onnx.load(f\"{temp_dir}/{output_name}.onnx\")\n",
    "mod, params = relay.frontend.from_onnx(onnx_model, {\"data\": shape}, freeze_params=True)\n",
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    mod = relay.quantize.prerequisite_optimize(mod, params)\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>data: Tensor[(<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">4</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">4</span>), float32] span<span style=\"color: #AA22FF; font-weight: bold\">=/</span>Expand<span style=\"color: #AA22FF; font-weight: bold\">.</span>data:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">9</span>i64, <span style=\"color: #008000\">3</span>i64, <span style=\"color: #008000\">16</span>i64), float32] {\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> broadcast_to(<span style=\"color: #AA22FF; font-weight: bold\">%</span>data, shape<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">3</span>i64, <span style=\"color: #008000\">1</span>i64, <span style=\"color: #008000\">4</span>i64]) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">3</span>i64, <span style=\"color: #008000\">1</span>i64, <span style=\"color: #008000\">4</span>i64), float32] span<span style=\"color: #AA22FF; font-weight: bold\">=/</span>Expand:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>;\n",
       "  tile(<span style=\"color: #AA22FF; font-weight: bold\">%</span><span style=\"color: #008000\">0</span>, reps<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">3</span>i64, <span style=\"color: #008000\">3</span>i64, <span style=\"color: #008000\">4</span>i64]) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">9</span>i64, <span style=\"color: #008000\">3</span>i64, <span style=\"color: #008000\">16</span>i64), float32] span<span style=\"color: #AA22FF; font-weight: bold\">=/</span>Tile:<span style=\"color: #008000\">0</span>:<span style=\"color: #008000\">0</span> <span style=\"color: #AA22FF; font-weight: bold\">*/</span>\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "with tvm.transform.PassContext(opt_level=3):\n",
    "    with relay.quantize.qconfig(\n",
    "        skip_conv_layers=[],\n",
    "        # calibrate_mode=\"kl_divergence\", \n",
    "        weight_scale=\"max\",\n",
    "        round_for_shift=True,\n",
    "        # rounding=\"TONEAREST\", # \"UPWARD\" or \"TONEAREST\"\n",
    "        # calibrate_skip_layers=[],\n",
    "        skip_dense_layer=False,\n",
    "    ):\n",
    "        qmod = relay.quantize.quantize(mod, params)\n",
    "qmod.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
